package org.hhf.rpt.service;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.hhf.rpt.exception.BizException;
import org.hhf.rpt.exception.CoreAssert;
import org.hhf.rpt.util.StringUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.common.CompositeStringExpression;
import org.springframework.expression.common.TemplateParserContext;
import org.springframework.expression.spel.standard.SpelExpression;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;

import java.io.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Loads a report definition XML file and renders the SQL template it contains.
 *
 * <p>The XML file is looked up either on the classpath as {@code rpt/<rptCode>.xml} or, when a
 * directory is given, as {@code <dir>/<rptCode>.xml}. The root element may carry a {@code show}
 * attribute (when it equals {@code "true"}, SQL and column names are logged) and has the child
 * elements {@code sql}, {@code colNames} and {@code rptName}.
 *
 * <p>SQL templates use SpEL template syntax ({@code #{#name}}). A template line wrapped in braces,
 * e.g. {@code {and name = #{#name}}}, is an optional clause: it is kept (with the braces removed)
 * only when every parameter it references has a non-blank value, and is dropped otherwise.
 *
 * @author haohaifeng
 * @date 2021/1/15 10:36
 */
@Slf4j
public class SqlParser {
    private static final String PREFIX = "rpt/";
    private static final String SUFFIX = ".xml";

    /** Shared SpEL parser; Spring documents SpelExpressionParser instances as reusable and thread-safe. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();
    /** Template context using the default {@code #{ ... }} placeholder delimiters. */
    private static final TemplateParserContext TEMPLATE_CONTEXT = new TemplateParserContext();

    private final SqlXml sqlXml;

    /**
     * Loads the report definition {@code rpt/<rptCode>.xml} from the classpath.
     *
     * @param rptCode report code, used as the XML file name (without extension)
     */
    public SqlParser(String rptCode) {
        this(rptCode, null);
    }

    /**
     * Loads the report definition for {@code rptCode}.
     *
     * @param rptCode report code, used as the XML file name (without extension)
     * @param path    directory holding the XML files; when blank the classpath {@code rpt/} folder is used
     */
    public SqlParser(String rptCode, String path) {
        this.sqlXml = parseXml(rptCode, path);
    }

    /** Parsed content of one report definition XML file. */
    @Data
    private static class SqlXml {
        private String rptCode;
        private String rptName;
        private String colNames;
        private String show;
        private String sql;
    }

    /**
     * Reads and parses the report definition XML for {@code rptCode}.
     *
     * @param rptCode report code (XML file name without extension)
     * @param path    directory holding the XML files, or blank to load from the classpath
     * @return the parsed definition
     * @throws BizException if the file cannot be located, read, or parsed
     */
    public SqlXml parseXml(String rptCode, String path) {
        Document doc;
        // Close the stream once parsed so that each parser does not leak a file handle.
        try (InputStream inputStream = getXmlFileStream(rptCode, path)) {
            doc = new SAXReader().read(inputStream);
        } catch (DocumentException | IOException e) {
            log.error(e.getMessage(), e);
            throw new BizException("100100", "文件流读取异常");
        }
        Element root = doc.getRootElement();
        CoreAssert.notNull(root, "rptCode不存在");
        SqlXml xml = new SqlXml();
        // attributeValue yields null for a missing "show" attribute instead of throwing an NPE.
        xml.setShow(root.attributeValue("show"));
        xml.setSql(root.elementText("sql"));
        xml.setColNames(root.elementText("colNames"));
        xml.setRptName(root.elementText("rptName"));
        xml.setRptCode(rptCode);
        return xml;
    }

    /**
     * Opens the report XML file, from the classpath when {@code absolutePath} is blank, otherwise
     * from that directory. The caller is responsible for closing the returned stream.
     *
     * @param rptCode      report code (XML file name without extension)
     * @param absolutePath directory holding the XML files, or blank for the classpath
     * @return an open stream over the XML file
     * @throws BizException if the directory is invalid or the file cannot be opened
     */
    private InputStream getXmlFileStream(String rptCode, String absolutePath) {
        if (StringUtils.isBlank(absolutePath)) {
            ClassPathResource classPathResource = new ClassPathResource(PREFIX + rptCode + SUFFIX);
            try {
                return classPathResource.getInputStream();
            } catch (IOException e) {
                log.error(e.getMessage(), e);
                throw new BizException("100101", "获取报表文件异常");
            }
        }
        String dir = absolutePath.endsWith(File.separator) ? absolutePath : absolutePath + File.separator;
        if (!new File(dir).isDirectory()) {
            log.error("xml文件目录路径维护异常");
            throw new BizException("100102", "XML文件目录维护异常，请配置参数：rpt.xml.dir.absPath");
        }
        String xmlFilePath = dir + rptCode + SUFFIX;
        try {
            return new FileInputStream(xmlFilePath);
        } catch (FileNotFoundException e) {
            log.error(e.getMessage(), e);
            throw new BizException("100103", new Object[]{xmlFilePath});
        }
    }

    /**
     * Returns the raw SQL template, logging it when the definition's {@code show} flag is "true".
     *
     * @return the unrendered SQL template
     */
    public String getSQL() {
        String sql = sqlXml.getSql();
        if (isShowEnabled()) {
            log.info("{} : {}", sqlXml.getRptCode(), sql);
        }
        return sql;
    }

    /**
     * Renders the SQL template with the given parameters.
     *
     * @param obj parameter values; must be a {@link Map} keyed by placeholder name
     * @return the rendered SQL
     * @throws BizException if {@code obj} is not a {@link Map}
     */
    public String parseSQL(Object obj) {
        String sql = parseSql(sqlXml.getSql(), obj);
        if (isShowEnabled()) {
            log.info("{} : {}", sqlXml.getRptCode(), sql);
        }
        return sql;
    }

    /**
     * Returns the report name.
     *
     * @return the {@code rptName} element text, or null if absent
     */
    public String getRptName() {
        return sqlXml.getRptName();
    }

    /**
     * Returns the export header column names configured in {@code colNames}.
     *
     * @return the column names, split on ',' or ';'
     * @throws BizException if no column names are configured
     */
    public String[] getTitleCols() {
        String colNames = sqlXml.getColNames();
        if (isShowEnabled()) {
            log.info(colNames);
        }
        if (StringUtils.isBlank(colNames)) {
            log.error("导出的表头没有配置");
            throw new BizException("100105", "导出的表头没有配置");
        }
        // Drop line breaks, then split on ',' or ';' together with any surrounding whitespace.
        return colNames.replaceAll("\\r|\\n", "").trim().split("\\s*(,|;)\\s*");
    }

    /** Whether the definition's {@code show} attribute requests logging. */
    private boolean isShowEnabled() {
        return "true".equals(sqlXml.getShow());
    }

    /**
     * Renders a SQL template with the given parameters.
     *
     * <p>Lines whose trimmed text starts with '{' and ends with '}' and that contain at least one
     * {@code #{...}} placeholder are optional clauses. When every referenced parameter has a
     * non-blank value, the braces are stripped and the values are bound as quoted SQL string
     * literals. Otherwise the clause is removed. Brace-wrapped lines without any placeholder are
     * left untouched. The resulting text is then evaluated as a SpEL template.
     *
     * @param sqlTemplate SQL template text
     * @param obj         parameter values; must be a {@link Map} keyed by placeholder name
     * @return the rendered SQL
     * @throws BizException if {@code obj} is not a {@link Map}
     */
    private static String parseSql(String sqlTemplate, Object obj) {
        if (!(obj instanceof Map)) {
            throw new BizException("100102", "参数传递异常");
        }
        Map<?, ?> params = (Map<?, ?>) obj;
        // Variables are registered only for the optional clauses that are kept.
        EvaluationContext context = new StandardEvaluationContext();
        String sql = sqlTemplate;
        List<String> optionalRows = Arrays.stream(sqlTemplate.split("\n"))
                .map(String::trim)
                .filter(row -> row.startsWith("{") && row.endsWith("}"))
                .collect(Collectors.toList());
        for (String row : optionalRows) {
            List<String> keys = placeholderKeys(row);
            if (keys.isEmpty()) {
                // Brace-wrapped but without a placeholder: not an optional clause, keep as is.
                continue;
            }
            if (keys.stream().allMatch(key -> isPresent(params.get(key)))) {
                sql = sql.replace(row, row.substring(1, row.length() - 1));
                for (String key : keys) {
                    context.setVariable(key, toSqlLiteral(params.get(key)));
                }
            } else {
                sql = sql.replace(row, "");
            }
        }
        // The whole text is a template expression, so it must be parsed with the template context.
        Expression expression = EXPRESSION_PARSER.parseExpression(sql, TEMPLATE_CONTEXT);
        return expression.getValue(context, String.class);
    }

    /**
     * Returns the parameter names referenced by the {@code #{...}} placeholders of one template
     * line. A name is the placeholder's expression text with every '#' removed, so
     * {@code #{#name}} yields {@code name}.
     *
     * @param row trimmed template line
     * @return referenced names in order of appearance; empty if the line has no placeholder
     */
    private static List<String> placeholderKeys(String row) {
        Expression parsed = EXPRESSION_PARSER.parseExpression(row, TEMPLATE_CONTEXT);
        Expression[] parts = parsed instanceof CompositeStringExpression
                ? ((CompositeStringExpression) parsed).getExpressions()
                : new Expression[]{parsed};
        return Arrays.stream(parts)
                .filter(part -> part instanceof SpelExpression)
                .map(part -> part.getExpressionString().replace("#", ""))
                .collect(Collectors.toList());
    }

    /** Whether a parameter value is non-null and its string form is not blank. */
    private static boolean isPresent(Object value) {
        return value != null && StringUtils.isNotBlank(value.toString());
    }

    /**
     * Quotes a parameter value as a SQL string literal, doubling embedded single quotes so that the
     * value cannot terminate the literal early.
     * NOTE(review): MySQL (unless NO_BACKSLASH_ESCAPES is set) also treats backslash as an escape
     * character inside literals, so this alone is not sufficient there; binding values through a
     * PreparedStatement is the robust fix.
     */
    private static String toSqlLiteral(Object value) {
        return "'" + value.toString().replace("'", "''") + "'";
    }

    /** Ad-hoc manual check: loads the SPE_SURVEY_LIST report from the classpath and prints the results. */
    public static void main(String[] args) {
        String rptCode = "SPE_SURVEY_LIST";
        SqlParser sqlUtils = new SqlParser(rptCode);
        Map<String, Object> params = new HashMap<>();
        params.put("surveyName", "test");
        String sql1 = sqlUtils.getSQL();
        Arrays.stream(sql1.split("\n"))
                .filter(n -> n.trim().startsWith("{") && n.trim().endsWith("}"))
                .forEach(System.out::println);
        String sql2 = sqlUtils.parseSQL(params);
        final String[] titleCols = sqlUtils.getTitleCols();
        System.out.println(sql1);
        System.out.println(sql2);
        System.out.println(String.join(",", titleCols));
        System.out.println(sqlUtils.getRptName());
    }
}